import os
import torch.nn as nn

root = os.getcwd()

# 沟槽微结构图像路径
img_root = root + r'\data\image'

# 激光加工参数路径
txt_root = root + r'\data\txt'

# 模型检查点权重参数路径
checkpoint_root = root + r'\checkpoint'

# 最优模型参数保存路径
weight_root = root + r'\weight'

# 接触角预测误差值保存路径
csv_root = root + r'\csv'

# 损失函数
Loss_function = {
    'L1': nn.L1Loss(),
    'L1SMOOTH': nn.SmoothL1Loss(),
    'MSE': nn.MSELoss()
}

# 创建缺失的项目文件
if not os.path.isdir(img_root) and os.path.isdir(txt_root) and os.path.isdir(checkpoint_root) and os.path.isdir(weight_root) and os.path.isdir(csv_root):
    os.makedirs(img_root)
    os.makedirs(txt_root)
    os.makedirs(checkpoint_root)
    os.makedirs(weight_root)
    os.makedirs(csv_root)
